#!/usr/bin/python3
# -*- coding:UTF-8 -*-

import AutoExecUtils
import os
import traceback
import datetime
import argparse
import sys
import json
from bson.json_util import dumps, loads
from DSLTools import Parser, Interpreter

# Cache of parsed app-system inspect definitions; getAppMetrisDef looks entries
# up by app system id.
appSystemParsedCache = {}


def usage():
    """Print the command-line synopsis of this script and exit with status 1.

    ``sys.exit`` is used instead of the ``exit`` helper, which is injected by
    the ``site`` module for interactive use and is not guaranteed to exist
    (for example under ``python -S``).
    """
    pname = os.path.basename(__file__)
    print(pname + " --outputfile <path> [--node <node json>] [--statustarget all|inspect|monitor]")
    sys.exit(1)


def getObjCatPkDef(objCat, customPkDef):
    """Return the primary-key attribute names for an object category.

    Built-in categories always use their fixed key definition; for any other
    category the caller supplied ``customPkDef`` is returned unchanged.
    """
    builtinPkDefs = {
        'INS': ['MGMT_IP', 'PORT'],
        'DB': ['PRIMARY_IP', 'PORT', 'NAME'],
        'CLUSTER': ['PRIMARY_IP', 'PORT', 'NAME'],
        'DBINS': ['MGMT_IP', 'PORT', 'INSTANCE_NAME'],
        'OS': ['MGMT_IP'],
        'HOST': ['MGMT_IP', 'BOARD_SERIAL'],
        'NETDEV': ['MGMT_IP', 'SN'],
        'SECDEV': ['MGMT_IP', 'SN'],
        'SERVERDEV': ['MGMT_IP'],
        'VIRTUALIZED': ['MGMT_IP'],
        'SWITCH': ['MGMT_IP', 'SN'],
        'FIREWALL': ['MGMT_IP', 'SN'],
        'LOADBALANCER': ['MGMT_IP', 'SN'],
        'STORAGE': ['MGMT_IP', 'SN'],
        'FCSWITCH': ['MGMT_IP', 'SN'],
        'K8S': ['MGMT_IP'],
        'CONTAINER': ['MGMT_IP', 'CONTAINER_ID'],
        'UNKNOWN': ['MGMT_IP'],
    }
    if objCat in builtinPkDefs:
        return builtinPkDefs[objCat]
    return customPkDef


def getLevelNo(level):
    """Map an alert level name (case-insensitive) to its numeric severity.

    WARN -> 1, CRITICAL -> 2, FATAL -> 3; any other name yields 0.
    """
    severityByName = {'WARN': 1, 'CRITICAL': 2, 'FATAL': 3}
    return severityByName.get(level.upper(), 0)


def getAlertSummarized(alerts):
    """Aggregate a list of alert dicts into per-level and per-field summaries.

    Returns a dict with:
      * ``alertOutline``    - number of alerts per rule level,
      * ``alertSummarized`` - total alert count and the most severe level seen
        ('NORMAL' when there are no alerts),
      * ``alertFields``     - one entry per distinct ``jsonPath`` carrying the
        most severe level plus the sequences and names of the rules that hit it.
    """
    worstLevel = 'NORMAL'
    levelCounts = {}
    fieldsByPath = {}
    seen = 0

    for item in alerts:
        seen += 1
        itemLevel = item['ruleLevel']
        levelCounts[itemLevel] = levelCounts.get(itemLevel, 0) + 1
        # Strict comparison: on equal severity the first level seen is kept.
        if getLevelNo(itemLevel) > getLevelNo(worstLevel):
            worstLevel = itemLevel

        path = item['jsonPath']
        if path in fieldsByPath:
            entry = fieldsByPath[path]
            if getLevelNo(itemLevel) > getLevelNo(entry['alertLevel']):
                entry['alertLevel'] = itemLevel
            entry['ruleSeqs'].append(item['ruleSeq'])
            entry['ruleNames'].append(item['ruleName'])
        else:
            fieldsByPath[path] = {'alertLevel': itemLevel,
                                  'alertField': path,
                                  'ruleSeqs': [item['ruleSeq']],
                                  'ruleNames': [item['ruleName']]}

    return {"alertOutline": levelCounts,
            'alertSummarized': {'totalCount': seen, 'status': worstLevel},
            'alertFields': list(fieldsByPath.values())}


def getAppMetrisDef(db, appSysIds, defName, parentRuleMap):
    """Load the app-system specific threshold definitions for an inspect definition.

    Args:
        db: database handle; the ``_inspectdef_app`` collection is read and,
            when an AST had to be (re)generated, updated.
        appSysIds: app system ids to load definitions for. A single scalar id
            is accepted and treated as a one-element list.
        defName: name of the inspect definition the app definitions belong to.
        parentRuleMap: rules of the base definition keyed by ``ruleUuid``; used
            to give an overwriting app rule the sequence of the rule it replaces.

    Returns:
        A list of app definition dicts (without ``_id``). Every rule in
        ``thresholds`` carries ``ruleSeq``, ``appSystemId`` and ``AST``.

    NOTE(review): ``appSystemParsedCache`` is consulted here but nothing in the
    visible code ever stores into it, so every id is currently queried.
    """
    if not appSysIds:
        return []
    if not isinstance(appSysIds, (list, tuple, set)):
        appSysIds = [appSysIds]

    newAppDefs = []
    queryAppSysIds = []
    for appSysId in appSysIds:
        cachedAppDef = appSystemParsedCache.get(appSysId)
        if cachedAppDef is not None:
            newAppDefs.append(cachedAppDef)
        else:
            queryAppSysIds.append(appSysId)

    if not queryAppSysIds:
        return newAppDefs

    appDefCollection = db['_inspectdef_app']
    # Find the app-system custom thresholds. Only ids that were not served from
    # the cache are queried, otherwise cached definitions would be returned twice.
    appDefs = appDefCollection.find({'appSystemId': {'$in': queryAppSysIds}, 'name': defName})

    for appDef in appDefs:
        # Compare the AST generation time with the rule modification time; a
        # stale or missing AST is re-parsed and the result is written back.
        ruleSeq = 5000
        ruleNeedUpdate = False
        currentTime = datetime.datetime.utcnow()
        for ruleDef in appDef['thresholds']:
            parentRule = None
            if ruleDef.get('isOverWrite') == 1:
                parentRule = parentRuleMap.get(ruleDef.get('ruleUuid'))
                if parentRule is None:
                    print('WARN: Can not find parent rule {} for overwrite rule of app system {}, treat it as a standalone rule.'.format(
                        ruleDef.get('ruleUuid'), appDef.get('appSystemId')))

            if parentRule is None:
                ruleSeq = ruleSeq + 1
                ruleDef['ruleSeq'] = '%s#%d' % (appDef.get('appSystemAbbrName'), ruleSeq)
            else:
                # An overwriting rule inherits the sequence of the rule it replaces.
                ruleDef['ruleSeq'] = parentRule.get('ruleSeq')
            ruleDef['appSystemId'] = appDef['appSystemId']

            # Without a recorded modification time the rule is treated as
            # modified "now", which always forces a re-parse.
            updateTime = ruleDef.get('_updatetime', currentTime)
            astGenTime = ruleDef.get('_astgentime')
            try:
                astIsFresh = ('AST' in ruleDef
                              and astGenTime is not None
                              and updateTime is not None
                              and astGenTime >= updateTime)
            except TypeError:
                # Timestamps of incomparable types: regenerate to be safe.
                astIsFresh = False

            if not astIsFresh:
                ruleDef['_astgentime'] = currentTime
                ast = Parser(ruleDef['rule'])
                ruleDef['AST'] = ast.asList()
                ruleNeedUpdate = True

        # The definition is always returned without its database id.
        defId = appDef.pop('_id', None)
        if ruleNeedUpdate and defId is not None:
            appDefCollection.update_one({'_id': defId},
                                        {'$set':
                                         {'thresholds': appDef['thresholds']}
                                         })
        newAppDefs.append(appDef)

    return newAppDefs


def getMetricsDef(db, objCat, objType):
    """Load the inspect definition (fields and threshold rules) for an object kind.

    The definition is looked up in ``_inspectdef`` by collection name
    (``COLLECT_<OBJCAT>``) and object type, falling back to the definition
    whose ``filter._OBJ_TYPE`` is null. When none exists an empty definition
    named ``'empty'`` is returned.

    Every threshold rule gets a ``ruleSeq`` of the form ``#<n>`` (1-based
    position). A rule whose AST is missing or older than the rule's
    ``_updatetime`` is re-parsed, and the refreshed thresholds are written back
    to the database.

    Returns:
        The definition dict, without its ``_id``.
    """
    inspectDefCollection = db['_inspectdef']
    collectName = 'COLLECT_' + objCat.upper()
    inspectDef = inspectDefCollection.find_one({'collection': collectName,
                                                'filter._OBJ_TYPE': objType}
                                               )
    if inspectDef is None:
        inspectDef = inspectDefCollection.find_one({'collection': collectName,
                                                    'filter._OBJ_TYPE': None}
                                                   )

    if inspectDef is None:
        return {'name': 'empty',
                'fields': [],
                'thresholds': []}

    # Compare the AST generation time with the rule modification time; a stale
    # or missing AST is re-parsed and the result is written back to the DB.
    ruleSeq = 0
    ruleNeedUpdate = False
    currentTime = datetime.datetime.utcnow()
    for ruleDef in inspectDef['thresholds']:
        ruleSeq = ruleSeq + 1
        ruleDef['ruleSeq'] = '#%d' % (ruleSeq)

        # Without a recorded modification time the rule is treated as modified
        # "now", which always forces a re-parse.
        updateTime = ruleDef.get('_updatetime', currentTime)
        astGenTime = ruleDef.get('_astgentime')
        try:
            astIsFresh = ('AST' in ruleDef
                          and astGenTime is not None
                          and updateTime is not None
                          and astGenTime >= updateTime)
        except TypeError:
            # Timestamps of incomparable types: regenerate to be safe.
            astIsFresh = False

        if not astIsFresh:
            ruleDef['_astgentime'] = currentTime
            ast = Parser(ruleDef['rule'])
            ruleDef['AST'] = ast.asList()
            ruleNeedUpdate = True

    # The definition is always returned without its database id.
    defId = inspectDef.pop('_id', None)
    if ruleNeedUpdate and defId is not None:
        inspectDefCollection.update_one({'_id': defId},
                                        {'$set':
                                         {'thresholds': inspectDef['thresholds']}
                                         })

    return inspectDef

# Fetch the full CMDB record for each collected object


def getObjCatFullData(datalist):
    """Complete partial collected records with the full record stored in the CMDB.

    For every record the primary key is built from the key definition of its
    ``_OBJ_CATEGORY`` (or the record's own ``PK`` for custom categories) and
    the matching document is read from ``COLLECT_<category>``. Attributes of
    the incoming record override those of the stored document.

    A record is passed through unchanged, without a database lookup, when it
    has no category, no key definition, or lacks one of the key attributes.

    Uses the module-level ``db`` handle.

    Returns:
        A new list with one entry per input record.
    """
    new_datalist = []
    for data in datalist:
        objCat = data.get('_OBJ_CATEGORY')
        objType = data.get('_OBJ_TYPE')
        typeName = objType if objType is not None else objCat

        if objCat is None:
            print("WARN: Data not defined _OBJ_CATEGORY, can not complete it from CMDB.")
            new_datalist.append(data)
            continue

        pkDef = getObjCatPkDef(objCat, data.get('PK'))
        if not pkDef:
            print("WARN: {} has no PK definition, can not complete it from CMDB.".format(typeName))
            new_datalist.append(data)
            continue

        # Build the primary-key filter from the PK definition.
        pkInvalid = False
        primaryKey = {}
        for pKey in pkDef:
            if pKey in data:
                pVal = data[pKey]
                primaryKey[pKey] = pVal
                if pVal is None or pVal == '':
                    print("WARN: {} PK attribute:{} is empty.".format(typeName, pKey))
            else:
                pkInvalid = True
                print("WARN: {} not contain PK attribute:{}.".format(typeName, pKey))

        if pkInvalid:
            # An incomplete key could match an unrelated document, so keep the
            # record as collected and skip the lookup (it must be added once only).
            new_datalist.append(data)
            continue

        collection = db["COLLECT_" + objCat]
        fullData = collection.find_one(primaryKey, {'_id': False})
        if fullData is None:
            fullData = data
        else:
            fullData.update(data)
        new_datalist.append(fullData)
    return new_datalist


def getInspectData(nodeInfo):
    """Extract the collected records that belong to the node being inspected.

    The node output of the previous steps is loaded through
    ``AutoExecUtils.loadNodeOutput()`` and split by plugin prefix:

      * ``inspect/``     - availability check results, merged into the
        availability attributes;
      * ``svcinspect/``  - either a map of availability attributes, or a list
        of collected records whose first record is also merged into the
        availability attributes and which is completed from the CMDB;
      * ``cmdbcollect/`` - collected CMDB records.

    Records matching the node's ip (``MGMT_IP`` or ``PRIMARY_IP``), port and,
    for DBINS/DB, name are returned with the availability attributes merged in.
    When nothing matches, a single record with category and type ``EMPTY`` is
    returned so that availability can still be reported.

    Args:
        nodeInfo: node definition; ``host``, ``port`` and ``nodeName`` are read.

    Returns:
        A non-empty list of record dicts (also when the node output is missing).
    """
    ip = nodeInfo.get('host')
    port = nodeInfo.get('port')
    nodeName = nodeInfo.get('nodeName')

    dataRecords = []
    inspectAttrs = {}
    collectData = []

    # Load the output of the previous CMDB collect step. The collection may
    # return several records, the matching ones are extracted below.
    outputData = AutoExecUtils.loadNodeOutput()
    if outputData is None:
        print("WARN: Node output data is empty.")
        outputData = {}

    for key, value in outputData.items():
        if key.startswith("inspect/"):
            # Results of the availability check plugins
            # (pingcheck, urlcheck, sshcheck, agentcheck, tcpcheck).
            aInspectData = value.get('DATA')
            if aInspectData is not None:
                inspectAttrs.update(aInspectData)
            continue

        if key.startswith("svcinspect/"):
            # svcinspect plugins return either an availability map or a list
            # of collected CMDB records.
            svcData = value.get('DATA')
            if not isinstance(svcData, list):
                if svcData is not None:
                    inspectAttrs.update(svcData)
                continue

            if svcData and svcData[0] is not None:
                inspectAttrs.update(svcData[0])

            # svcinspect data is an increment on top of the CMDB data while
            # thresholds are matched against full records, so complete it.
            value['DATA'] = getObjCatFullData(svcData)
        elif not key.startswith("cmdbcollect/"):
            continue

        pluginData = value.get('DATA')
        if pluginData is None:
            print('WARN: Plugin {} did not return collect data.'.format(key))
            continue
        # The last plugin that actually returned data wins.
        collectData = pluginData

    if not collectData:
        print("WARN: Can not find collect data in node output.")

    if 'AVAILABILITY' in inspectAttrs:
        inspectAttrs['AVAILABILITY'] = inspectAttrs['AVAILABILITY'] != 0
    else:
        inspectAttrs['AVAILABILITY'] = False
        errMsg = 'AVAILABILITY attribute not found in inspect data, must use availability tools(pingcheck,urlcheck,sshcheck,agentcheck,tcpcheck...) to collect availability first.'
        inspectAttrs['ERROR_MESSAGE'] = errMsg

    # Categories present in the collected data set.
    objCats = {data.get('_OBJ_CATEGORY') for data in collectData}

    for data in collectData:
        objCat = data.get('_OBJ_CATEGORY')
        if objCat is None:
            print('WARN: Data not defined _OBJ_CATEGORY.')
            # default=str keeps the dump from failing on non-JSON values.
            print(json.dumps(data, default=str))
            continue

        # Only records that belong to this node are extracted.
        if data.get('MGMT_IP') != ip and data.get('PRIMARY_IP') != ip:
            print('WARN: Data not match IP:{} but:{}.'.format(ip, data.get('MGMT_IP') or data.get('PRIMARY_IP')))
            continue

        if port is not None:
            if 'PORT' not in data:
                print('WARN: Data has no port.')
                continue

            # Compare numerically when possible, the port may be a string on
            # either side; otherwise compare the textual form.
            try:
                portMatched = int(data['PORT']) == int(port)
            except (TypeError, ValueError):
                portMatched = str(data['PORT']) == str(port)
            if not portMatched:
                print('WARN: Data not match port:{} but:{}.'.format(port, data['PORT']))
                continue

        if objCat == 'DBINS' and data.get('INSTANCE_NAME') != nodeName:
            continue
        if objCat == 'DB' and data.get('NAME') != nodeName:
            continue
        if objCat == 'HOST' and 'OS' in objCats:
            # The HOST record is skipped when an OS record is present.
            continue

        print("INFO: Data matched {}/{}:{}/{}.".format(objCat, ip, port, nodeName))
        if 'RESOURCE_ID' in data:
            inspectAttrs['RESOURCE_ID'] = data['RESOURCE_ID']
        data.update(inspectAttrs)
        dataRecords.append(data)

    if not dataRecords:
        print("WARN: Inspect data is empty, set _OBJ_CATEGORY to EMPTY.")
        inspectAttrs['_OBJ_CATEGORY'] = 'EMPTY'
        inspectAttrs['_OBJ_TYPE'] = 'EMPTY'
        dataRecords.append(inspectAttrs)

    return dataRecords


def _inspectRuleKey(ruleDef):
    """Return the key identifying a rule: its ruleUuid, or its ruleSeq when it has none."""
    ruleKey = ruleDef.get('ruleUuid')
    if ruleKey is None:
        ruleKey = ruleDef.get('ruleSeq')
    return ruleKey


def inspectData(nodeInfo):
    """Evaluate the threshold rules against the data collected for a node.

    For every record returned by ``getInspectData`` the inspect definition of
    its category/type is loaded, app-system rules are merged over the base
    rules (same ``ruleUuid`` replaces the base rule), every rule is run through
    the DSL ``Interpreter`` and the resulting alerts are summarized. A
    CRITICAL availability alert (and an error-message alert, when one is
    present) is added for unavailable nodes.

    Uses the module-level ``db`` handle.

    Args:
        nodeInfo: node definition; ``resourceId``, ``host``, ``port`` and
            ``appSystemId`` are read.

    Returns:
        A list of report dicts, each carrying the selected fields of the
        record and the evaluation under ``_inspect_result``.
    """
    rptDataRecords = []
    execUser = os.getenv('AUTOEXEC_USER')
    inspectDefMap = {}
    objCatFieldsMap = {}
    # Data collected by the CMDB steps for this single node.
    dataRecords = getInspectData(nodeInfo)
    for dataCollected in dataRecords:
        if dataCollected is None:
            continue

        availability = dataCollected.get('AVAILABILITY', False)
        objCat = dataCollected['_OBJ_CATEGORY']
        objType = dataCollected.get('_OBJ_TYPE')

        # Definitions are cached per (category, type) because the lookup
        # depends on both; later records reuse the cached definition.
        defKey = (objCat, objType)
        if defKey not in inspectDefMap:
            loadedDef = getMetricsDef(db, objCat, objType)
            needFieldsMap = {
                'AVAILABILITY': 1,
                'ERROR_MESSAGE': 1,
                'RESPONSE_TIME': 1
            }
            for fieldDef in loadedDef['fields']:
                if fieldDef.get('selected') == 1:
                    needFieldsMap[fieldDef['name']] = 1
            inspectDefMap[defKey] = loadedDef
            objCatFieldsMap[defKey] = needFieldsMap
        inspectDef = inspectDefMap[defKey]
        needFields = objCatFieldsMap[defKey]

        # Records whose resource can not be resolved are not reported.
        # NOTE(review): the original code had an unreachable fallback to
        # nodeInfo['resourceId'] after this check; confirm whether skipping
        # the record is really the intended behaviour.
        resourceId = dataCollected.get('RESOURCE_ID')
        if resourceId is None or resourceId == '0':
            if getResourceId(dataCollected) is None:
                print("WARN: Can not find resource Id for {}/{}:{}/{}.".format(objCat, dataCollected.get('MGMT_IP'), dataCollected.get('PORT'), dataCollected.get('NAME')))
                continue

        appSysIds = nodeInfo.get('appSystemId', None)
        defName = inspectDef['name']
        # thresholdsMap: rules to evaluate, keyed by rule identity so that an
        # app rule with the same ruleUuid replaces the base rule.
        # inspectRulesMap: rules by sequence, used to attach the rule of an
        # alert to the report.
        thresholdsMap = {}
        inspectRulesMap = {}
        for ruleDef in inspectDef['thresholds']:
            thresholdsMap[_inspectRuleKey(ruleDef)] = ruleDef
            inspectRulesMap[ruleDef.get('ruleSeq')] = ruleDef

        for appDef in getAppMetrisDef(db, appSysIds, defName, thresholdsMap):
            for appRuleDef in appDef['thresholds']:
                thresholdsMap[_inspectRuleKey(appRuleDef)] = appRuleDef
                inspectRulesMap[appRuleDef.get('ruleSeq')] = appRuleDef

        rptData = {'RESOURCE_ID': nodeInfo['resourceId'],
                   'MGMT_IP': nodeInfo['host'],
                   '_execuser': execUser
                   }

        # Copy the attributes required by the inspect definition into the report.
        for needField in needFields:
            if needField in dataCollected:
                rptData[needField] = dataCollected[needField]

        # Each alert is a dict with jsonPath, ruleLevel, ruleAppId, ruleSeq,
        # ruleName and fieldValue, e.g.
        #   {"jsonPath": "$.DISKS[0].CAPACITY", "ruleLevel": "WARN",
        #    "ruleAppId": 234315, "ruleSeq": "ABS#12", "ruleName": "...",
        #    "fieldValue": 92}
        alerts = []
        for inspectRule in thresholdsMap.values():
            interpreter = Interpreter(AST=inspectRule['AST'], ruleAppId=inspectRule.get('appSystemId'), ruleSeq=inspectRule['ruleSeq'], ruleName=inspectRule['name'], ruleLevel=inspectRule['level'], data=dataCollected)
            ruleAlerts = interpreter.resolve()
            if ruleAlerts:
                alerts.extend(ruleAlerts)

        inspectResult = {'status': 'NORMAL',
                         'name': inspectDef['name'],
                         # The empty default definition has no label.
                         'label': inspectDef.get('label', inspectDef['name']),
                         'alerts': alerts}

        rptData['PORT'] = nodeInfo.get('port')

        # Rules that produced an alert are stored in the report so the
        # matched rule can be displayed.
        alertThresHolds = {}
        for alert in alerts:
            ruleSeq = alert['ruleSeq']
            if ruleSeq in inspectRulesMap:
                alertThresHolds[ruleSeq] = inspectRulesMap[ruleSeq]
        inspectResult['hasAlert'] = bool(alerts)

        if not availability:
            inspectResult['hasAlert'] = True
            alerts.append({
                "jsonPath": "$.AVAILABILITY",
                "ruleLevel": "CRITICAL",
                'ruleSeq': '#01',
                'ruleName': "可用性告警",
                'fieldValue': availability
            })
            alertThresHolds['#01'] = {'name': '可用性告警',
                                      'level': 'CRITICAL',
                                      'ruleSeq': '#01',
                                      'rule': '$.AVAILABILITY != True'}

            if rptData.get('ERROR_MESSAGE'):
                alerts.append({
                    "jsonPath": "$.ERROR_MESSAGE",
                    "ruleLevel": "CRITICAL",
                    'ruleName': "错误信息非空",
                    'ruleSeq': '#02',
                    'fieldValue': rptData.get('ERROR_MESSAGE')
                })
                alertThresHolds['#02'] = {'name': '错误信息非空',
                                          'level': 'CRITICAL',
                                          'ruleSeq': '#02',
                                          'rule': '$.ERROR_MESSAGE != ""'}

        inspectResult['thresholds'] = alertThresHolds

        alertSum = getAlertSummarized(alerts)
        alertSummarized = alertSum['alertSummarized']
        inspectResult['alertOutline'] = alertSum['alertOutline']
        inspectResult['alertFields'] = alertSum['alertFields']
        inspectResult['totalCount'] = alertSummarized['totalCount']
        inspectResult['status'] = alertSummarized['status']
        rptData['_inspect_result'] = inspectResult

        rptDataRecords.append(rptData)
    return rptDataRecords


def saveReport(db, ciType, resourceId, rptData, statusTarget='inspect'):
    """Persist the latest inspect report of a resource and publish its status.

    The report is stamped with the current UTC time and the job id from
    ``AUTOEXEC_JOBID``, upserted into ``INSPECT_REPORTS`` by ``RESOURCE_ID``
    and appended to the report history. Depending on ``statusTarget`` the
    resulting status and alert count are pushed to the inspect status
    ('inspect'), the monitor status ('monitor'), or both ('both' / 'all').
    """
    reports = db['INSPECT_REPORTS']

    rptData['_report_time'] = datetime.datetime.utcnow()
    rptData['_jobid'] = os.getenv('AUTOEXEC_JOBID')

    reports.replace_one({'RESOURCE_ID': rptData['RESOURCE_ID']}, rptData, upsert=True)
    for indexField, indexName in (('RESOURCE_ID', 'idx_pk'), ('_jobid', 'idx_jobid')):
        reports.create_index([(indexField, 1)], name=indexName)
    print('INFO: Save report data success.')

    saveHisReport(db, ciType, resourceId, rptData)

    if statusTarget in ('inspect', 'both', 'all'):
        pushStatus = AutoExecUtils.updateInspectStatus
        summary = rptData['_inspect_result']
        pushStatus(ciType, resourceId, summary['status'], summary['totalCount'])
        print('INFO: Update inspect status for node success.')

    if statusTarget in ('monitor', 'both', 'all'):
        pushStatus = AutoExecUtils.updateMonitorStatus
        summary = rptData['_inspect_result']
        pushStatus(ciType, resourceId, summary['status'], summary['totalCount'])
        print('INFO: Update monitor status for node success.')


def saveHisReport(db, ciType, resourceId, rptData):
    """Append a report to the ``INSPECT_REPORTS_HIS`` history collection.

    Besides inserting the document this makes sure the lookup index on
    (RESOURCE_ID, _report_time) and the TTL index on ``_report_time``
    (7776000 seconds, i.e. 90 days) exist. ``ciType`` and ``resourceId`` are
    accepted but not used here.
    """
    history = db['INSPECT_REPORTS_HIS']
    history.insert_one(rptData)

    # An index on (_jobid, _report_time) named 'idx_jobid' is intentionally
    # not created; it was disabled in the original code.
    indexSpecs = (
        ([('RESOURCE_ID', 1), ('_report_time', 1)], {'name': 'idx_pk'}),
        ([('_report_time', 1)], {'name': 'idx_ttl', 'expireAfterSeconds': 7776000}),
    )
    for indexKeys, indexOptions in indexSpecs:
        history.create_index(indexKeys, **indexOptions)

    print('INFO: Save report history data success.')

# Reverse-lookup the resource id in the resource center


def getResourceId(jsonData):
    """Look up the resource id of a collected record in the resource center.

    The lookup uses the record's ``MGMT_IP``, ``PORT``, ``NAME`` and
    ``_OBJ_TYPE``; the port is left out for CONTAINER records and for records
    without a port.

    Returns:
        The resource id when exactly one resource matches, otherwise ``None``
        (also when the record has neither a management ip nor a name).
    """
    objCat = jsonData.get('_OBJ_CATEGORY')
    objType = jsonData.get('_OBJ_TYPE')
    mgmtIp = jsonData.get('MGMT_IP')
    port = jsonData.get('PORT')
    name = jsonData.get('NAME')

    # Nothing to search by when both the ip and the name are missing or empty.
    if (mgmtIp is None or mgmtIp == '') and (name is None or name == ''):
        return None

    lookupPort = None if objCat == "CONTAINER" else port
    resourceList = AutoExecUtils.getResourceInfoList(mgmtIp, lookupPort, name, objType)

    # Only an unambiguous match is accepted.
    if resourceList is not None and len(resourceList) == 1:
        return resourceList[0]['id']
    return None


if __name__ == "__main__":
    # Script entry point: resolve the node definition and the node output
    # path (options first, then environment), run the inspection for the node
    # and save one report per inspected record.
    parser = argparse.ArgumentParser()
    parser.add_argument('--outputfile', default='', help='Output json file path for node')
    parser.add_argument('--node', default='', help='Execution node json')
    parser.add_argument('--statustarget', default='inspect', help='Inspect status save target, all|inspect|monitor')
    args = parser.parse_args()
    outputPath = args.outputfile
    node = args.node
    statusTarget = args.statustarget

    dbclient = None
    # Bound before the try block so the error handler can always reference it.
    nodeInfo = {}
    try:
        hasOptError = False
        if not node:
            node = os.getenv('AUTOEXEC_NODE')
        if not node:
            print("ERROR: Can not find node definition.\n")
            hasOptError = True
        else:
            nodeInfo = json.loads(node)

        if not outputPath:
            outputPath = os.getenv('NODE_OUTPUT_PATH')
        else:
            os.environ['NODE_OUTPUT_PATH'] = outputPath

        # An empty path is as unusable as a missing one.
        if not outputPath:
            print("ERROR: Must set environment variable NODE_OUTPUT_PATH or defined option --outputfile.\n")
            hasOptError = True

        if hasOptError:
            usage()

        ciType = nodeInfo['nodeType']
        resourceId = nodeInfo['resourceId']
        # db stays a module-level name: getObjCatFullData and inspectData read it.
        (dbclient, db) = AutoExecUtils.getDB()

        rptDataRecords = inspectData(nodeInfo)
        for rptData in rptDataRecords:
            saveReport(db, ciType, resourceId, rptData, statusTarget)

    except Exception:
        # Never print the node password in the error output.
        if isinstance(nodeInfo, dict) and 'password' in nodeInfo:
            nodeInfo['password'] = '******'
        print('ERROR: Unknow Error for node {}, {}'.format(nodeInfo, traceback.format_exc()))
        sys.exit(-1)
    finally:
        if dbclient is not None:
            dbclient.close()
